{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b9a4a87af62afb0c",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Langсhain\n",
    "\n",
    "## What is Langchain?\n",
    "\n",
    "[Langchain](https://github.com/langchain-ai/langchain) stands out as a leading AI framework, renowned for its unique approach to \"Constructing applications using LLMs via composability.\"\n",
    "\n",
    "But, while LangChain facilitates orchestration, it doesn't directly handle LLM security. That's where LLM Guard comes into play.\n",
    "\n",
    "## What is LCEL?\n",
    "\n",
    "[LangChain Expression Language or LCEL](https://python.langchain.com/docs/expression_language/) is a declarative way to easily compose chains together.\n",
    "\n",
    "We can chain LLM Guard and the LLM sequentially. This means that we check if LLM Guard has identified any security risk in the prompt before it is sent to the LLM to get an output.\n",
    "\n",
    "And then use another scanner to check if the output from the LLM is safe to be sent to the user.\n",
    "\n",
    "In [examples/langchain.py](https://github.com/protectai/llm-guard/blob/main/examples/langchain.py), you can find an example of how to use LCEL to compose LLM Guard chains.\n",
    "\n",
    "----"
   ]
  },
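  {
   "cell_type": "markdown",
   "id": "c41d7e9a02b35f68",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Before bringing in the real libraries, the sequential flow described above can be sketched in plain Python. This is only a rough illustration: `scan_prompt`, `fake_llm`, `scan_output` and `guarded_call` are hypothetical stand-ins, not part of LLM Guard or LangChain. Each toy scanner is assumed to return a `(text, is_valid, risk_score)` tuple, which is the shape the scanner calls later in this notebook unpack."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a2e05bd91c64f33",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from typing import Callable, Tuple\n",
    "\n",
    "# (text, is_valid, risk_score): the result shape assumed for every toy scanner here\n",
    "ScanResult = Tuple[str, bool, float]\n",
    "\n",
    "\n",
    "def scan_prompt(prompt: str) -> ScanResult:\n",
    "    \"\"\"Toy input scanner: redact a banned word and flag the prompt.\"\"\"\n",
    "    banned = \"secret\"\n",
    "    if banned in prompt:\n",
    "        return prompt.replace(banned, \"[REDACTED]\"), False, 1.0\n",
    "    return prompt, True, 0.0\n",
    "\n",
    "\n",
    "def fake_llm(prompt: str) -> str:\n",
    "    \"\"\"Stand-in for a model call; it simply echoes the prompt.\"\"\"\n",
    "    return f\"Echo: {prompt}\"\n",
    "\n",
    "\n",
    "def scan_output(prompt: str, output: str) -> ScanResult:\n",
    "    \"\"\"Toy output scanner: reject blank responses.\"\"\"\n",
    "    is_valid = bool(output.strip())\n",
    "    return output, is_valid, 0.0 if is_valid else 1.0\n",
    "\n",
    "\n",
    "def guarded_call(prompt: str, llm: Callable[[str], str]) -> str:\n",
    "    # 1. Scan (and here, redact) the prompt before it reaches the model.\n",
    "    sanitized_prompt, _, _ = scan_prompt(prompt)\n",
    "    # 2. Call the model with the sanitized prompt only.\n",
    "    output = llm(sanitized_prompt)\n",
    "    # 3. Scan the model output before returning it to the user.\n",
    "    sanitized_output, is_valid, risk_score = scan_output(sanitized_prompt, output)\n",
    "    if not is_valid:\n",
    "        raise ValueError(f\"Output rejected with risk score {risk_score}\")\n",
    "    return sanitized_output\n",
    "\n",
    "\n",
    "print(guarded_call(\"Tell me the secret recipe\", fake_llm))"
   ]
  },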
  {
   "cell_type": "markdown",
   "id": "9bfd1c8728de58d3",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Let's try to integrate LLM Guard with Langchain. \n",
    "\n",
    "Start by installing the dependencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ded389c35d76ed07",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "!pip install llm-guard langchain openai"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7c5c5b17bf06901",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "In case, you need faster inference, use ONNX."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3499ff7dc1e3e01",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "!pip install llm-guard[onnxruntime]\n",
    "!pip install llm-guard[onnxruntime-gpu]\n",
    "\n",
    "use_onnx = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98b6f3dddc0bc15f",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "However, we won't use it in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bcf1075118e81cf",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "use_onnx = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38a04cf18260144e",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Before we start, we need to set O API key. In order to get it, go to https://platform.openai.com/api-keys."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90d1e57fd9bd0c21",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "openai_api_key = \"sk-your-key\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e419a06f8fcc02c",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Then, we can create prompt scanner that uses `Chain` from `langchain`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4285ef1fe44e79a6",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "from typing import Any, Dict, List, Optional, Union\n",
    "\n",
    "from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun\n",
    "from langchain.chains.base import Chain\n",
    "from langchain.pydantic_v1 import BaseModel, root_validator\n",
    "from langchain.schema.messages import BaseMessage\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "try:\n",
    "    import llm_guard\n",
    "except ImportError:\n",
    "    raise ModuleNotFoundError(\n",
    "        \"Could not import llm-guard python package. \"\n",
    "        \"Please install it with `pip install llm-guard`.\"\n",
    "    )\n",
    "\n",
    "\n",
    "class LLMGuardPromptException(Exception):\n",
    "    \"\"\"Exception to raise when llm-guard marks prompt invalid.\"\"\"\n",
    "\n",
    "\n",
    "class LLMGuardPromptChain(Chain):\n",
    "    scanners: Dict[str, Dict] = {}\n",
    "    \"\"\"The scanners to use.\"\"\"\n",
    "    scanners_ignore_errors: List[str] = []\n",
    "    \"\"\"The scanners to ignore if they throw errors.\"\"\"\n",
    "    vault: Optional[llm_guard.vault.Vault] = None\n",
    "    \"\"\"The scanners to ignore errors from.\"\"\"\n",
    "    raise_error: bool = True\n",
    "    \"\"\"Whether to raise an error if the LLMGuard marks the prompt invalid.\"\"\"\n",
    "\n",
    "    input_key: str = \"input\"  #: :meta private:\n",
    "    output_key: str = \"sanitized_input\"  #: :meta private:\n",
    "    initialized_scanners: List[Any] = []  #: :meta private:\n",
    "\n",
    "    @root_validator(pre=True)\n",
    "    def init_scanners(cls, values: Dict[str, Any]) -> Dict[str, Any]:\n",
    "        \"\"\"\n",
    "        Initializes scanners\n",
    "\n",
    "        Args:\n",
    "            values (Dict[str, Any]): A dictionary containing configuration values.\n",
    "\n",
    "        Returns:\n",
    "            Dict[str, Any]: A dictionary with the updated configuration values,\n",
    "                            including the initialized scanners.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If there is an issue importing 'llm-guard' or loading scanners.\n",
    "        \"\"\"\n",
    "\n",
    "        if values.get(\"initialized_scanners\") is not None:\n",
    "            return values\n",
    "        try:\n",
    "            if values.get(\"scanners\") is not None:\n",
    "                values[\"initialized_scanners\"] = []\n",
    "                for scanner_name in values.get(\"scanners\"):\n",
    "                    scanner_config = values.get(\"scanners\")[scanner_name]\n",
    "                    if scanner_name == \"Anonymize\":\n",
    "                        scanner_config[\"vault\"] = values[\"vault\"]\n",
    "\n",
    "                    values[\"initialized_scanners\"].append(\n",
    "                        llm_guard.input_scanners.get_scanner_by_name(scanner_name, scanner_config)\n",
    "                    )\n",
    "\n",
    "            return values\n",
    "        except Exception as e:\n",
    "            raise ValueError(\n",
    "                \"Could not initialize scanners. \" f\"Please check provided configuration. {e}\"\n",
    "            ) from e\n",
    "\n",
    "    @property\n",
    "    def input_keys(self) -> List[str]:\n",
    "        \"\"\"\n",
    "        Returns a list of input keys expected by the prompt.\n",
    "\n",
    "        This method defines the input keys that the prompt expects in order to perform\n",
    "        its processing. It ensures that the specified keys are available for providing\n",
    "        input to the prompt.\n",
    "\n",
    "        Returns:\n",
    "           List[str]: A list of input keys.\n",
    "\n",
    "        Note:\n",
    "           This method is considered private and may not be intended for direct\n",
    "           external use.\n",
    "        \"\"\"\n",
    "        return [self.input_key]\n",
    "\n",
    "    @property\n",
    "    def output_keys(self) -> List[str]:\n",
    "        \"\"\"\n",
    "        Returns a list of output keys.\n",
    "\n",
    "        This method defines the output keys that will be used to access the output\n",
    "        values produced by the chain or function. It ensures that the specified keys\n",
    "        are available to access the outputs.\n",
    "\n",
    "        Returns:\n",
    "            List[str]: A list of output keys.\n",
    "\n",
    "        Note:\n",
    "            This method is considered private and may not be intended for direct\n",
    "            external use.\n",
    "\n",
    "        \"\"\"\n",
    "        return [self.output_key]\n",
    "\n",
    "    def _check_result(\n",
    "        self,\n",
    "        scanner_name: str,\n",
    "        is_valid: bool,\n",
    "        risk_score: float,\n",
    "        run_manager: Optional[CallbackManagerForChainRun] = None,\n",
    "    ):\n",
    "        if is_valid:\n",
    "            return  # prompt is valid, keep scanning\n",
    "\n",
    "        if run_manager:\n",
    "            run_manager.on_text(\n",
    "                text=f\"This prompt was determined as invalid by {scanner_name} scanner with risk score {risk_score}\",\n",
    "                color=\"red\",\n",
    "                verbose=self.verbose,\n",
    "            )\n",
    "\n",
    "        if scanner_name in self.scanners_ignore_errors:\n",
    "            return  # ignore error, keep scanning\n",
    "\n",
    "        if self.raise_error:\n",
    "            raise LLMGuardPromptException(\n",
    "                f\"This prompt was determined as invalid based on configured policies with risk score {risk_score}\"\n",
    "            )\n",
    "\n",
    "    async def _acall(\n",
    "        self,\n",
    "        inputs: Dict[str, Any],\n",
    "        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,\n",
    "    ) -> Dict[str, str]:\n",
    "        raise NotImplementedError(\"Async not implemented yet\")\n",
    "\n",
    "    def _call(\n",
    "        self,\n",
    "        inputs: Dict[str, str],\n",
    "        run_manager: Optional[CallbackManagerForChainRun] = None,\n",
    "    ) -> Dict[str, str]:\n",
    "        \"\"\"\n",
    "        Executes the scanning process on the prompt and returns the sanitized prompt.\n",
    "\n",
    "        This internal method performs the scanning process on the prompt. It uses the\n",
    "        provided scanners to scan the prompt and then returns the sanitized prompt.\n",
    "        Additionally, it provides the option to log information about the run using\n",
    "        the provided `run_manager`.\n",
    "\n",
    "        Args:\n",
    "            inputs: A dictionary containing input values\n",
    "            run_manager: A run manager to handle run-related events. Default is None\n",
    "\n",
    "        Returns:\n",
    "            Dict[str, str]: A dictionary containing the processed output.\n",
    "\n",
    "        Raises:\n",
    "            LLMGuardPromptException: If there is an error during the scanning process\n",
    "        \"\"\"\n",
    "        if run_manager:\n",
    "            run_manager.on_text(\"Running LLMGuardPromptChain...\\n\")\n",
    "\n",
    "        sanitized_prompt = inputs[self.input_keys[0]]\n",
    "        for scanner in self.initialized_scanners:\n",
    "            sanitized_prompt, is_valid, risk_score = scanner.scan(sanitized_prompt)\n",
    "            self._check_result(type(scanner).__name__, is_valid, risk_score, run_manager)\n",
    "\n",
    "        return {self.output_key: sanitized_prompt}"
   ]
  },
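  {
   "cell_type": "markdown",
   "id": "e83b16f4a7d05c29",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "The decision logic in `_check_result` may be easier to follow in isolation. The sketch below is a simplified, dependency-free restatement of it: the `decide` helper and its string return values are hypothetical and exist only for this illustration, and the callback logging done by the real method is left out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19f6c0d8b35a47e2",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from typing import List\n",
    "\n",
    "\n",
    "def decide(scanner_name: str, is_valid: bool, ignore: List[str], raise_error: bool) -> str:\n",
    "    \"\"\"Return what the chain would do with a single scanner result.\"\"\"\n",
    "    if is_valid:\n",
    "        return \"continue\"  # nothing flagged, keep scanning\n",
    "    if scanner_name in ignore:\n",
    "        return \"continue\"  # flagged, but this scanner may not fail the prompt\n",
    "    if raise_error:\n",
    "        return \"raise\"  # flagged by a blocking scanner\n",
    "    return \"continue\"  # flagged, but raising is disabled\n",
    "\n",
    "\n",
    "ignored = [\"Anonymize\", \"Secrets\"]\n",
    "print(decide(\"Anonymize\", False, ignored, True))\n",
    "print(decide(\"Toxicity\", False, ignored, True))\n",
    "print(decide(\"Toxicity\", False, ignored, False))"
   ]
  },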
  {
   "cell_type": "markdown",
   "id": "1656799a71acfb29",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Once it's done, we can configure that scanner:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f8292757233e9b8",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "vault = llm_guard.vault.Vault()\n",
    "\n",
    "llm_guard_prompt_scanner = LLMGuardPromptChain(\n",
    "    vault=vault,\n",
    "    scanners={\n",
    "        \"Anonymize\": {\"use_faker\": True, \"use_onnx\": use_onnx},\n",
    "        \"BanSubstrings\": {\n",
    "            \"substrings\": [\"Laiyer\"],\n",
    "            \"match_type\": \"word\",\n",
    "            \"case_sensitive\": False,\n",
    "            \"redact\": True,\n",
    "        },\n",
    "        \"BanTopics\": {\"topics\": [\"violence\"], \"threshold\": 0.7, \"use_onnx\": use_onnx},\n",
    "        \"Code\": {\"denied\": [\"go\"], \"use_onnx\": use_onnx},\n",
    "        \"Language\": {\"valid_languages\": [\"en\"], \"use_onnx\": use_onnx},\n",
    "        \"PromptInjection\": {\"threshold\": 0.95, \"use_onnx\": use_onnx},\n",
    "        \"Regex\": {\"patterns\": [\"Bearer [A-Za-z0-9-._~+/]+\"]},\n",
    "        \"Secrets\": {\"redact_mode\": \"all\"},\n",
    "        \"Sentiment\": {\"threshold\": -0.05},\n",
    "        \"TokenLimit\": {\"limit\": 4096},\n",
    "        \"Toxicity\": {\"threshold\": 0.8, \"use_onnx\": use_onnx},\n",
    "    },\n",
    "    scanners_ignore_errors=[\n",
    "        \"Anonymize\",\n",
    "        \"BanSubstrings\",\n",
    "        \"Regex\",\n",
    "        \"Secrets\",\n",
    "        \"TokenLimit\",\n",
    "        \"PromptInjection\",\n",
    "    ],  # These scanners redact, so I can skip them from failing the prompt\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31e93b91926592ee",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Once it's configured, we can try to guard the chain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faaeb7a4733166cf",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from langchain.chat_models import ChatOpenAI\n",
    "from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate\n",
    "from langchain.schema.messages import SystemMessage\n",
    "from langchain.schema.output_parser import StrOutputParser\n",
    "\n",
    "llm = ChatOpenAI(openai_api_key=openai_api_key, model_name=\"gpt-3.5-turbo-1106\")\n",
    "\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        SystemMessage(\n",
    "            content=\"You are a helpful assistant, which creates the best SQL queries based on my command\"\n",
    "        ),\n",
    "        HumanMessagePromptTemplate.from_template(\"{sanitized_input}\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "input_prompt = \"Make an SQL insert statement to add a new user to our database. Name is John Doe. Email is test@test.com \"\n",
    "\"but also possible to contact him with hello@test.com email. Phone number is 555-123-4567 and \"\n",
    "\"the IP address is 192.168.1.100. And credit card number is 4567-8901-2345-6789. \"\n",
    "\"He works in Test LLC.\"\n",
    "guarded_chain = (\n",
    "    llm_guard_prompt_scanner  # scan input here\n",
    "    | prompt\n",
    "    | llm\n",
    "    | StrOutputParser()\n",
    ")\n",
    "\n",
    "result = guarded_chain.invoke(\n",
    "    {\n",
    "        \"input\": input_prompt,\n",
    "    }\n",
    ")\n",
    "\n",
    "print(\"Result: \" + result)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a96da0aa15c06f20",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Now let's guard output as well. We need to start with configuring the chain\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b025c29a55576530",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class LLMGuardOutputException(Exception):\n",
    "    \"\"\"Exception to raise when llm-guard marks output invalid.\"\"\"\n",
    "\n",
    "\n",
    "class LLMGuardOutputChain(BaseModel):\n",
    "    class Config:\n",
    "        arbitrary_types_allowed = True\n",
    "\n",
    "    scanners: Dict[str, Dict] = {}\n",
    "    \"\"\"The scanners to use.\"\"\"\n",
    "    scanners_ignore_errors: List[str] = []\n",
    "    \"\"\"The scanners to ignore if they throw errors.\"\"\"\n",
    "    vault: Optional[llm_guard.vault.Vault] = None\n",
    "    \"\"\"The scanners to ignore errors from.\"\"\"\n",
    "    raise_error: bool = True\n",
    "    \"\"\"Whether to raise an error if the LLMGuard marks the output invalid.\"\"\"\n",
    "\n",
    "    initialized_scanners: List[Any] = []  #: :meta private:\n",
    "\n",
    "    @root_validator(pre=True)\n",
    "    def init_scanners(cls, values: Dict[str, Any]) -> Dict[str, Any]:\n",
    "        \"\"\"\n",
    "        Initializes scanners\n",
    "\n",
    "        Args:\n",
    "            values (Dict[str, Any]): A dictionary containing configuration values.\n",
    "\n",
    "        Returns:\n",
    "            Dict[str, Any]: A dictionary with the updated configuration values,\n",
    "                            including the initialized scanners.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If there is an issue importing 'llm-guard' or loading scanners.\n",
    "        \"\"\"\n",
    "\n",
    "        if values.get(\"initialized_scanners\") is not None:\n",
    "            return values\n",
    "        try:\n",
    "            if values.get(\"scanners\") is not None:\n",
    "                values[\"initialized_scanners\"] = []\n",
    "                for scanner_name in values.get(\"scanners\"):\n",
    "                    scanner_config = values.get(\"scanners\")[scanner_name]\n",
    "                    if scanner_name == \"Deanonymize\":\n",
    "                        scanner_config[\"vault\"] = values[\"vault\"]\n",
    "\n",
    "                    values[\"initialized_scanners\"].append(\n",
    "                        llm_guard.output_scanners.get_scanner_by_name(scanner_name, scanner_config)\n",
    "                    )\n",
    "\n",
    "            return values\n",
    "        except Exception as e:\n",
    "            raise ValueError(\n",
    "                \"Could not initialize scanners. \" f\"Please check provided configuration. {e}\"\n",
    "            ) from e\n",
    "\n",
    "    def _check_result(\n",
    "        self,\n",
    "        scanner_name: str,\n",
    "        is_valid: bool,\n",
    "        risk_score: float,\n",
    "    ):\n",
    "        if is_valid:\n",
    "            return  # prompt is valid, keep scanning\n",
    "\n",
    "        logger.warning(\n",
    "            f\"This output was determined as invalid by {scanner_name} scanner with risk score {risk_score}\"\n",
    "        )\n",
    "\n",
    "        if scanner_name in self.scanners_ignore_errors:\n",
    "            return  # ignore error, keep scanning\n",
    "\n",
    "        if self.raise_error:\n",
    "            raise LLMGuardOutputException(\n",
    "                f\"This output was determined as invalid based on configured policies with risk score {risk_score}\"\n",
    "            )\n",
    "\n",
    "    def scan(\n",
    "        self,\n",
    "        prompt: str,\n",
    "        output: Union[BaseMessage, str],\n",
    "    ) -> Union[BaseMessage, str]:\n",
    "        sanitized_output = output\n",
    "        if isinstance(output, BaseMessage):\n",
    "            sanitized_output = sanitized_output.content\n",
    "\n",
    "        for scanner in self.initialized_scanners:\n",
    "            sanitized_output, is_valid, risk_score = scanner.scan(prompt, sanitized_output)\n",
    "            self._check_result(type(scanner).__name__, is_valid, risk_score)\n",
    "\n",
    "        if isinstance(output, BaseMessage):\n",
    "            output.content = sanitized_output\n",
    "            return output\n",
    "\n",
    "        return sanitized_output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c837eed450b965b",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Then we need to configure the scanners:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bf43ba120b7dd43",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "llm_guard_output_scanner = LLMGuardOutputChain(\n",
    "    vault=vault,\n",
    "    scanners={\n",
    "        \"BanSubstrings\": {\n",
    "            \"substrings\": [\"Laiyer\"],\n",
    "            \"match_type\": \"word\",\n",
    "            \"case_sensitive\": False,\n",
    "            \"redact\": True,\n",
    "        },\n",
    "        \"BanTopics\": {\"topics\": [\"violence\"], \"threshold\": 0.7, \"use_onnx\": use_onnx},\n",
    "        \"Bias\": {\"threshold\": 0.75, \"use_onnx\": use_onnx},\n",
    "        \"Code\": {\"denied\": [\"go\"], \"use_onnx\": use_onnx},\n",
    "        \"Deanonymize\": {},\n",
    "        \"FactualConsistency\": {\"minimum_score\": 0.5, \"use_onnx\": use_onnx},\n",
    "        \"JSON\": {\"required_elements\": 0, \"repair\": True},\n",
    "        \"Language\": {\n",
    "            \"valid_languages\": [\"en\"],\n",
    "            \"threshold\": 0.5,\n",
    "            \"use_onnx\": use_onnx,\n",
    "        },\n",
    "        \"LanguageSame\": {\"use_onnx\": use_onnx},\n",
    "        \"MaliciousURLs\": {\"threshold\": 0.75, \"use_onnx\": use_onnx},\n",
    "        \"NoRefusal\": {\"threshold\": 0.5, \"use_onnx\": use_onnx},\n",
    "        \"Regex\": {\n",
    "            \"patterns\": [\"Bearer [A-Za-z0-9-._~+/]+\"],\n",
    "        },\n",
    "        \"Relevance\": {\"threshold\": 0.5, \"use_onnx\": use_onnx},\n",
    "        \"Sensitive\": {\"redact\": False, \"use_onnx\": use_onnx},\n",
    "        \"Sentiment\": {\"threshold\": -0.05},\n",
    "        \"Toxicity\": {\"threshold\": 0.7, \"use_onnx\": use_onnx},\n",
    "    },\n",
    "    scanners_ignore_errors=[\"BanSubstrings\", \"Regex\", \"Sensitive\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c13a4b6d9270014",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Once we have both prompt and output scanners, we can guard our chain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2952a29589a22a74",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "guarded_chain = (\n",
    "    llm_guard_prompt_scanner  # scan input here\n",
    "    | prompt\n",
    "    | llm\n",
    "    | (\n",
    "        lambda ai_message: llm_guard_output_scanner.scan(input_prompt, ai_message)\n",
    "    )  # scan output here and deanonymize\n",
    "    | StrOutputParser()\n",
    ")\n",
    "\n",
    "result = guarded_chain.invoke(\n",
    "    {\n",
    "        \"input\": input_prompt,\n",
    "    }\n",
    ")\n",
    "\n",
    "print(\"Result: \" + result)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
